# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import pathlib

import pandas
import pyarrow
import pyarrow.dataset
import pytest
from pandas.testing import assert_frame_equal

from adbc_driver_manager import dbapi


def test_type_objects():
    """DBAPI type markers compare equal to their Arrow counterparts."""
    # Equality must hold symmetrically in both directions.
    for marker, arrow_type in [
        (dbapi.NUMBER, pyarrow.int64()),
        (dbapi.STRING, pyarrow.string()),
    ]:
        assert marker == arrow_type
        assert arrow_type == marker

    # Distinct markers differ, but ROWID aliases NUMBER.
    assert dbapi.STRING != dbapi.NUMBER
    assert dbapi.NUMBER != dbapi.DATETIME
    assert dbapi.NUMBER == dbapi.ROWID


@pytest.mark.sqlite
def test_attrs(sqlite):
    """The connection re-exports the module-level exception hierarchy."""
    exception_names = (
        "Warning",
        "Error",
        "InterfaceError",
        "DatabaseError",
        "DataError",
        "OperationalError",
        "IntegrityError",
        "InternalError",
        "ProgrammingError",
        "NotSupportedError",
    )
    for name in exception_names:
        assert getattr(sqlite, name) == getattr(dbapi, name)

    # A fresh cursor starts in the DBAPI-documented default state.
    with sqlite.cursor() as cur:
        assert cur.arraysize == 1
        assert cur.connection is sqlite
        assert cur.description is None
        assert cur.rowcount == -1


@pytest.mark.sqlite
def test_info(sqlite):
    """adbc_get_info exposes driver and vendor metadata for SQLite."""
    expected_keys = {
        "driver_arrow_version",
        "driver_name",
        "driver_version",
        "vendor_name",
        "vendor_version",
    }
    info = sqlite.adbc_get_info()
    assert set(info.keys()) == expected_keys
    assert info["driver_name"] == "ADBC SQLite Driver"
    assert info["vendor_name"] == "SQLite"


@pytest.mark.sqlite
def test_get_underlying(sqlite):
    """The low-level ADBC handles are reachable from the DBAPI wrappers."""
    for handle in (sqlite.adbc_database, sqlite.adbc_connection):
        assert handle
    with sqlite.cursor() as cur:
        assert cur.adbc_statement


@pytest.mark.sqlite
def test_clone(sqlite):
    """A cloned connection shares the same underlying database."""
    with sqlite.adbc_clone() as clone:
        with clone.cursor() as cur:
            cur.execute("CREATE TABLE temporary (ints)")
            cur.execute("INSERT INTO temporary VALUES (1)")
        clone.commit()

    # Data committed on the clone is visible on the original connection.
    with sqlite.cursor() as cur:
        cur.execute("SELECT * FROM temporary")
        assert cur.fetchone() == (1,)


@pytest.mark.sqlite
def test_get_objects(sqlite):
    """adbc_get_objects returns the catalog/schema/table/column hierarchy."""
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE temporary (ints)")
        cur.execute("INSERT INTO temporary VALUES (1)")

    reader = sqlite.adbc_get_objects(table_name_filter="temporary")
    catalogs = reader.read_all().to_pylist()
    assert len(catalogs) == 1

    catalog = catalogs[0]
    assert catalog["catalog_name"] == "main"

    schemas = catalog["catalog_db_schemas"]
    assert len(schemas) == 1
    assert schemas[0]["db_schema_name"] == ""

    tables = schemas[0]["db_schema_tables"]
    assert len(tables) == 1
    table = tables[0]
    assert table["table_name"] == "temporary"
    assert table["table_type"] == "table"

    first_column = table["table_columns"][0]
    assert first_column["column_name"] == "ints"
    assert first_column["ordinal_position"] == 1
    assert table["table_constraints"] == []


@pytest.mark.sqlite
def test_get_table_schema(sqlite):
    """The driver infers an int64 column for the test table."""
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE temporary (ints)")
        cur.execute("INSERT INTO temporary VALUES (1)")
    expected = pyarrow.schema([("ints", pyarrow.int64())])
    assert sqlite.adbc_get_table_schema("temporary") == expected


@pytest.mark.sqlite
def test_get_table_types(sqlite):
    """SQLite reports exactly the 'table' and 'view' table types."""
    table_types = sqlite.adbc_get_table_types()
    assert table_types == ["table", "view"]


class ArrayWrapper:
    """Wrapper exposing only the Arrow PyCapsule array protocol.

    Used to verify that DBAPI entry points accept arbitrary objects
    implementing ``__arrow_c_array__``, not just pyarrow types.
    """

    def __init__(self, array):
        # The wrapped object must itself implement __arrow_c_array__.
        self.array = array

    def __arrow_c_array__(self, requested_schema=None):
        # Delegate directly to the wrapped array.
        return self.array.__arrow_c_array__(requested_schema=requested_schema)


class StreamWrapper:
    """Wrapper exposing only the Arrow PyCapsule stream protocol.

    Used to verify that DBAPI entry points accept arbitrary objects
    implementing ``__arrow_c_stream__``, not just pyarrow types.
    """

    def __init__(self, stream):
        # The wrapped object must itself implement __arrow_c_stream__.
        self.stream = stream

    def __arrow_c_stream__(self, requested_schema=None):
        # Delegate directly to the wrapped stream.
        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)


@pytest.mark.parametrize(
    "data",
    [
        lambda: pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table(
            [[1, 2], ["foo", ""]], names=["ints", "strs"]
        ).to_reader(),
        lambda: ArrayWrapper(
            pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
        ),
        lambda: StreamWrapper(
            pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
        ),
        lambda: pyarrow.table(
            [[1, 2], ["foo", ""]], names=["ints", "strs"]
        ).__arrow_c_stream__(),
    ],
)
@pytest.mark.sqlite
def test_ingest(data, sqlite):
    """adbc_ingest accepts every Arrow-compatible input and honors modes."""
    with sqlite.cursor() as cur:
        # The default mode creates the target table.
        cur.adbc_ingest("bulk_ingest", data())

        # Creating the same table again must fail.
        with pytest.raises(dbapi.Error):
            cur.adbc_ingest("bulk_ingest", data())

        # Appending duplicates the two rows.
        cur.adbc_ingest("bulk_ingest", data(), mode="append")

        # Appending to a table that does not exist fails.
        with pytest.raises(dbapi.Error):
            cur.adbc_ingest("nonexistent", data(), mode="append")

        # Unknown modes are rejected before touching the database.
        with pytest.raises(ValueError):
            cur.adbc_ingest("bulk_ingest", data(), mode="invalid")

    with sqlite.cursor() as cur:
        cur.execute("SELECT * FROM bulk_ingest")
        # Two rows from the create, two more from the append.
        for expected_row in [(1, "foo"), (2, ""), (1, "foo"), (2, "")]:
            assert cur.fetchone() == expected_row


@pytest.mark.sqlite
def test_partitions(sqlite):
    """The SQLite driver does not support partitioned result sets."""
    with sqlite.cursor() as cur:
        with pytest.raises(dbapi.NotSupportedError):
            cur.adbc_execute_partitions("SELECT 1")


@pytest.mark.sqlite
def test_query_fetch_py(sqlite):
    """Row-oriented fetch APIs: fetchone, fetchmany, fetchall, iteration."""
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        expected_description = [
            ("1", dbapi.NUMBER, None, None, None, None, None),
            ("foo", dbapi.STRING, None, None, None, None, None),
            ("2.0", dbapi.NUMBER, None, None, None, None, None),
        ]
        assert cur.description == expected_description
        assert cur.rownumber == 0
        assert cur.fetchone() == (1, "foo", 2.0)
        assert cur.rownumber == 1
        # The single row is exhausted; fetchone now signals end-of-data.
        assert cur.fetchone() is None

        # fetchmany defaults to arraysize (1) rows per call.
        cur.execute("SELECT 1, 'foo', 2.0")
        assert cur.fetchmany() == [(1, "foo", 2.0)]
        assert cur.fetchmany() == []

        # fetchall drains the result; a second call returns nothing.
        cur.execute("SELECT 1, 'foo', 2.0")
        assert cur.fetchall() == [(1, "foo", 2.0)]
        assert cur.fetchall() == []

        # The cursor itself is iterable.
        cur.execute("SELECT 1, 'foo', 2.0")
        assert [row for row in cur] == [(1, "foo", 2.0)]


@pytest.mark.sqlite
def test_query_fetch_arrow(sqlite):
    """fetch_arrow yields a C-stream capsule; misuse raises ProgrammingError."""
    with sqlite.cursor() as cur:
        # Nothing executed yet: there is no result to fetch.
        with pytest.raises(sqlite.ProgrammingError):
            cur.fetch_arrow()

        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        stream = cur.fetch_arrow().__arrow_c_stream__()
        imported = pyarrow.RecordBatchReader._import_from_c_capsule(stream)
        expected = pyarrow.table(
            {
                "1": [1],
                "foo": ["foo"],
                "2.0": [2.0],
            }
        )
        assert imported.read_all() == expected

        # The result was consumed above, so fetching again is an error.
        with pytest.raises(sqlite.ProgrammingError):
            cur.fetch_arrow()


@pytest.mark.sqlite
def test_query_fetch_arrow_3543(sqlite):
    """Reading .description must not consume the Arrow result stream."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3543
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")

        # Inspecting the description should leave the result intact.
        expected_description = [
            ("1", dbapi.NUMBER, None, None, None, None, None),
            ("foo", dbapi.STRING, None, None, None, None, None),
            ("2.0", dbapi.NUMBER, None, None, None, None, None),
        ]
        assert cur.description == expected_description

        stream = cur.fetch_arrow().__arrow_c_stream__()
        imported = pyarrow.RecordBatchReader._import_from_c_capsule(stream)
        expected = pyarrow.table(
            {
                "1": [1],
                "foo": ["foo"],
                "2.0": [2.0],
            }
        )
        assert imported.read_all() == expected


@pytest.mark.sqlite
def test_query_fetch_arrow_table(sqlite):
    """fetch_arrow_table materializes the result as a pyarrow.Table."""
    expected = pyarrow.table(
        {
            "1": [1],
            "foo": ["foo"],
            "2.0": [2.0],
        }
    )
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        assert cur.fetch_arrow_table() == expected


@pytest.mark.sqlite
def test_query_fetch_df(sqlite):
    """fetch_df materializes the result as a pandas DataFrame."""
    expected = pandas.DataFrame(
        {
            "1": [1],
            "foo": ["foo"],
            "2.0": [2.0],
        }
    )
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        assert_frame_equal(cur.fetch_df(), expected)


@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        (1.0, 2),
        pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
        pyarrow.table([[1.0], [2]], names=["float", "int"]),
        ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
        StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
    ],
)
def test_execute_parameters(sqlite, parameters):
    """Positional parameters may be plain sequences or Arrow data."""
    with sqlite.cursor() as cur:
        cur.execute("SELECT ? + 1, ?", parameters)
        rows = cur.fetchall()
        assert rows == [(2.0, 2)]


@pytest.mark.sqlite
def test_execute_parameters_name(sqlite):
    """Named (@x) and numbered (?N) parameter styles both bind correctly."""
    with sqlite.cursor() as cur:
        cur.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1})
        assert cur.fetchall() == [(2, 2)]

        # Ensure the state of the cursor isn't affected
        cur.execute("SELECT ?2 + 1, ?1", [2, 1])
        assert cur.fetchall() == [(2, 2)]

        # A named parameter may be referenced multiple times.
        cur.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1})
        assert cur.fetchall() == [(2, 4)]

        # Named binding does not interfere with subsequent ingestion.
        batch = pyarrow.record_batch([[1.0], [2]], names=["float", "int"])
        cur.adbc_ingest("ingest_tester", batch)
        cur.execute("SELECT * FROM ingest_tester")
        assert cur.fetchall() == [(1.0, 2)]


@pytest.mark.sqlite
def test_executemany_parameters_name(sqlite):
    """executemany accepts dict rows (named) and tuple rows (positional)."""
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE executemany_params (a, b)")

        named_rows = [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}]
        cur.executemany(
            "INSERT INTO executemany_params VALUES (@a, @b)", named_rows
        )
        positional_rows = [(3, 4), (4, 5)]
        cur.executemany(
            "INSERT INTO executemany_params VALUES (?, ?)", positional_rows
        )

        cur.execute("SELECT * FROM executemany_params ORDER BY a ASC")
        assert cur.fetchall() == [(1, 2), (2, 3), (3, 4), (4, 5)]


@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        [(1, "a"), (3, None)],
        pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
        pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
        pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
        ArrayWrapper(
            pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
        ),
        StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
        ((x, y) for x, y in ((1, "a"), (3, None))),
    ],
)
def test_executemany_parameters(sqlite, parameters):
    """executemany binds rows from sequences, Arrow data, and generators."""
    with sqlite.cursor() as cur:
        # Recreate the table so parametrized runs do not interfere.
        cur.execute("DROP TABLE IF EXISTS executemany")
        cur.execute("CREATE TABLE executemany (int, str)")
        cur.executemany("INSERT INTO executemany VALUES (? * 2, ?)", parameters)

        cur.execute("SELECT * FROM executemany ORDER BY int ASC")
        assert cur.fetchall() == [(2, "a"), (6, None)]


@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        [],
        pyarrow.record_batch([[]], schema=pyarrow.schema([("v", pyarrow.int64())])),
        pyarrow.table([[]], schema=pyarrow.schema([("v", pyarrow.int64())])),
    ],
)
def test_executemany_empty(sqlite, parameters):
    """executemany with zero parameter rows must be a no-op."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3319
    with sqlite.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS executemany")
        cur.execute("CREATE TABLE executemany (v)")
        # No rows to bind: the INSERT should never actually run.
        cur.executemany("INSERT INTO executemany VALUES (?)", parameters)

        cur.execute("SELECT * FROM executemany")
        rows = cur.fetchall()
        assert rows == []


@pytest.mark.sqlite
def test_executemany_none(sqlite):
    """Passing None as the parameter set to executemany is an error."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3319
    with sqlite.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS executemany")
        cur.execute("CREATE TABLE executemany (v)")
        # With None, it should be the same as executing the query once.
        with pytest.raises(sqlite.Error):
            cur.executemany("INSERT INTO executemany VALUES (?)", None)


@pytest.mark.sqlite
def test_query_substrait(sqlite):
    """The SQLite driver rejects Substrait (bytes) plans."""
    plan = b"Substrait plan"
    with sqlite.cursor() as cur:
        with pytest.raises(dbapi.NotSupportedError):
            cur.execute(plan)


@pytest.mark.sqlite
def test_executemany(sqlite):
    """executemany inserts each row; iteration tracks rownumber."""
    rows = [
        (1, 2),
        (3, 4),
        (5, 6),
    ]
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (a, b)")
        cur.executemany("INSERT INTO foo VALUES (?, ?)", rows)

        cur.execute("SELECT COUNT(*) FROM foo")
        assert cur.fetchone() == (3,)

        cur.execute("SELECT * FROM foo ORDER BY a ASC")
        # rownumber reflects how many rows have been consumed so far.
        for consumed, expected_row in enumerate(rows):
            assert cur.rownumber == consumed
            assert next(cur) == expected_row


@pytest.mark.sqlite
def test_fetch_record_batch(sqlite):
    """fetch_record_batch round-trips inserted rows through Arrow."""
    dataset = [
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10],
    ]
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (a, b)")
        cur.executemany("INSERT INTO foo VALUES (?, ?)", dataset)

        cur.execute("SELECT * FROM foo")
        reader = cur.fetch_record_batch()
        assert reader.read_pandas().values.tolist() == dataset


@pytest.mark.sqlite
def test_fetch_empty(sqlite):
    """Selecting from an empty table yields an empty row list."""
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (bar)")
        cur.execute("SELECT * FROM foo")
        rows = cur.fetchall()
        assert rows == []


@pytest.mark.sqlite
def test_reader(sqlite, tmp_path) -> None:
    """The record-batch reader can be consumed by pyarrow.dataset."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/1523
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1")
        stream = cur.fetch_record_batch()
        pyarrow.dataset.write_dataset(stream, tmp_path, format="parquet")


@pytest.mark.sqlite
def test_prepare(sqlite):
    """adbc_prepare returns the (possibly parameter-dependent) schema."""
    with sqlite.cursor() as cur:
        # A parameterless query has a concrete, empty schema.
        assert cur.adbc_prepare("SELECT 1") == pyarrow.schema([])

        # An unbound parameter surfaces as a null-typed column.
        assert cur.adbc_prepare("SELECT 1 + ?") == pyarrow.schema([("0", "null")])

        # Preparing does not prevent executing the statement afterwards.
        cur.execute("SELECT 1 + ?", (1,))
        assert cur.fetchone() == (2,)


@pytest.mark.sqlite
def test_close_warning(sqlite):
    """Dropping a live cursor or connection emits a ResourceWarning."""
    cursor_pattern = (
        r"A adbc_driver_manager.dbapi.Cursor was not explicitly close\(\)d"
    )
    with pytest.warns(ResourceWarning, match=cursor_pattern):
        cur = sqlite.cursor()
        del cur

    connection_pattern = (
        r"A adbc_driver_manager.dbapi.Connection was not explicitly close\(\)d"
    )
    with pytest.warns(ResourceWarning, match=connection_pattern):
        conn = dbapi.connect(driver="adbc_driver_sqlite")
        del conn


def _execute_schema(cur):
    """Invoke adbc_execute_schema, tolerating drivers that lack support."""
    try:
        cur.adbc_execute_schema("select 1")
    except dbapi.NotSupportedError:
        # Some drivers cannot report a schema without executing; that is fine
        # here — the caller only cares that the statement was exercised.
        pass


@pytest.mark.sqlite
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(lambda cursor: cursor.execute("SELECT 1"), id="execute"),
        pytest.param(
            lambda cursor: cursor.executemany("SELECT ?", [[1]]), id="executemany"
        ),
        pytest.param(
            lambda cursor: cursor.adbc_ingest(
                "test_release",
                pyarrow.table([[1]], names=["ints"]),
                mode="create_append",
            ),
            id="ingest",
        ),
        pytest.param(_execute_schema, id="execute_schema"),
        pytest.param(lambda cursor: cursor.adbc_prepare("select 1"), id="prepare"),
        pytest.param(
            lambda cursor: cursor.executescript("select 1"), id="executescript"
        ),
    ],
)
def test_release(sqlite, op) -> None:
    """Every cursor operation must release the previous result handle."""
    # Regression test. Ensure that subsequent operations free results of
    # earlier operations.
    with sqlite.cursor() as cur:
        cur.execute("select 1")
        # Do _not_ fetch the data so it is never imported.
        previous_handle = cur._results._handle
        assert previous_handle.is_valid

        op(cur)
        if previous_handle:
            # The original handle (if it exists) should have been released
            assert not previous_handle.is_valid


def test_driver_path():
    """Loading a nonexistent driver path raises ProgrammingError."""
    missing = pathlib.Path("/tmp/thisdriverdoesnotexist")
    with pytest.raises(
        dbapi.ProgrammingError,
        match="(dlopen|LoadLibraryExW).*failed:",
    ):
        with dbapi.connect(driver=missing):
            pass


@pytest.mark.sqlite
def test_dbapi_extensions(sqlite):
    """Non-standard conveniences: Connection.execute and chained cursors."""
    with sqlite.execute("SELECT ?", (1,)) as cur:
        assert cur.fetchone() == (1,)
        assert cur.fetchone() is None

        # Cursor.execute returns the cursor, so calls can be chained.
        assert cur.execute("SELECT 2").fetchall() == [(2,)]

    with sqlite.cursor() as cur:
        assert cur.execute("SELECT 1").fetchall() == [(1,)]
        assert cur.execute("SELECT 42").fetchall() == [(42,)]


@pytest.mark.sqlite
def test_connect(tmp_path: pathlib.Path, monkeypatch) -> None:
    """dbapi.connect accepts keyword, positional, URI, and manifest forms."""
    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            assert cur.fetchone() == (1,)

    # https://github.com/apache/arrow-adbc/issues/3517: allow positional
    # argument
    with dbapi.connect("adbc_driver_sqlite") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            assert cur.fetchone() == (1,)

    # https://github.com/apache/arrow-adbc/issues/3517: allow URI argument
    db = tmp_path / "test.db"
    with dbapi.connect("adbc_driver_sqlite", db.as_uri()) as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE foo (a)")
            cur.execute("INSERT INTO foo VALUES (1)")
        conn.commit()

    with dbapi.connect(driver="adbc_driver_sqlite", uri=db.as_uri()) as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM foo")
            assert cur.fetchone() == (1,)

    # monkeypatch.setenv requires a str value; passing the Path object
    # directly triggers a warning/error depending on the pytest version.
    monkeypatch.setenv("ADBC_DRIVER_PATH", str(tmp_path))
    with (tmp_path / "foobar.toml").open("w") as f:
        f.write(
            """
[Driver]
shared = "adbc_driver_foobar"
        """
        )
    # Just check that the driver gets detected and loaded (should fail)
    with pytest.raises(dbapi.ProgrammingError, match="NOT_FOUND"):
        with dbapi.connect("foobar://localhost:5439"):
            pass
